(** * Identifiers and freshness frames

    Each code generation creates a "freshness frame": the interval (for each
    kind of variable) between the initial [IRState] and the resulting one after
    generation of the code.

    We will need to reason about which frame different variables belong to.
    This file provides some generic tools to this end, which are then used to
    derive more specific facts in [LidBound] and [BidBound] respectively.

    Important definition:
    - [state_bound s id] : the variable [id] is "below" [s], i.e. [s] will never be
      able to generate [id] from now on.
    - [state_bound_between s1 s2 id] : the variable [id] has been generated by a state
      situated in the frame [s1;s2].

    Note: in the proof of the Helix compiler, only [state_bound_between] is used.
    Indeed, [state_bound] is too coarse an invariant to support reasoning about the
    sequence due to the subcomponent being compiled in the reversed order they are
    generated.

 *)

Require Import Helix.LLVMGen.Correctness_Prelude.
Require Import Helix.LLVMGen.IdLemmas.

Import ListNotations.

Set Implicit Arguments.
Set Strict Implicit.

Global Opaque resolve_PVar.

(** Reasoning about when identifiers are bound in certain states *)
Section StateBound.
  (* Section parameters.
     - [count] maps a state to a natural number; every ordering between
       states in this section is expressed through it.
     - [gen] is the identifier generator: applied to a name and a state it
       either fails or returns [inr (s', id)], a new state paired with the
       generated identifier. *)
  Variable count :  IRState -> nat.
  Variable gen : string -> cerr raw_id.

  (* TODO: Injective is sort of a lie... Is there a better thing to call this? *)
  (* Two successful calls to [gen], both with correct prefixes, that start
     from states with different [count] values yield different identifiers.
     Nothing is claimed when the two starting counts are equal. *)
  Definition count_gen_injective : Prop
    := forall s1 s1' s2 s2' name1 name2 id1 id2,
      inr (s1', id1) ≡ gen name1 s1 ->
      inr (s2', id2) ≡ gen name2 s2 ->
      count s1 ≢ count s2 ->
      is_correct_prefix name1 ->
      is_correct_prefix name2 ->
      id1 ≢ id2.

  (* A successful call to [gen] returns a state whose [count] is strictly
     greater than that of the state it started from. *)
  Definition count_gen_mono : Prop
    := forall s1 s2 name id,
      inr (s2, id) ≡ gen name s1 ->
      (count s2 > count s1)%nat.

  (* Standing assumptions for the remainder of the section: the generator
     satisfies both [count_gen_injective] and [count_gen_mono]. *)
  Variable INJ : count_gen_injective.
  Variable MONO : count_gen_mono.

  (* [state_bound s id] holds when [id] has been generated by an earlier
     [IRState]: there exist a correct prefix [name] and a state [s'] whose
     count is strictly below that of [s] such that [gen name s'] succeeds
     and returns [id] (together with some resulting state [s'']).

     The intuition is that `id` ends with a number that is *lower* than the
     count for the current state.
   *)
  Definition state_bound (s : IRState) (id : raw_id) : Prop
    := exists name s' s'',
      is_correct_prefix name /\
      (count s' < count s)%nat /\
      inr (s'', id) ≡ gen name s'.

  (* [state_bound_between s1 s2 id] holds when [id] is produced by [gen]
     (with a correct prefix) from some state [s'] whose count lies in the
     half-open interval [count s1 <= count s' < count s2].
     Note the order of the conjuncts: the strict upper bound comes first,
     the lower bound second. *)
  Definition state_bound_between (s1 s2 : IRState) (id : raw_id) : Prop
    := exists name s' s'',
      is_correct_prefix name /\
      (count s' < count s2)%nat /\
      count s' ≥ count s1 /\
      inr (s'', id) ≡ gen name s'.

  (** An identifier bound in [s1] is distinct from every identifier generated
      in the frame that starts at [s1] and ends at [s2]. *)
  Lemma state_bound_fresh :
    forall (s1 s2 : IRState) (id id' : raw_id),
      state_bound s1 id ->
      state_bound_between s1 s2 id' ->
      id ≢ id'.
  Proof.
    intros s1 s2 id id' BOUND BETWEEN.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    destruct BETWEEN as (pre' & sm & sm' & PRE' & UPPER & LOWER & GEN').
    (* [st] is strictly below [s1] whereas [sm] is at or above it, so the two
       generating states have different counts and [INJ] applies. *)
    eapply (INJ GEN GEN'); [lia | auto | auto].
  Qed.

  (** Variant of [state_bound_fresh] in which the frame may start at any
      state [s2] whose count is at least that of [s1]. *)
  Lemma state_bound_fresh' :
    forall (s1 s2 s3 : IRState) (id id' : raw_id),
      state_bound s1 id ->
      (count s1 <= count s2)%nat ->
      state_bound_between s2 s3 id' ->
      id ≢ id'.
  Proof.
    intros s1 s2 s3 id id' BOUND LE BETWEEN.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    destruct BETWEEN as (pre' & sm & sm' & PRE' & UPPER & LOWER & GEN').
    (* count st < count s1 <= count s2 <= count sm, hence the counts differ. *)
    eapply (INJ GEN GEN'); [lia | auto | auto].
  Qed.

  (** An identifier that is bound in [s2] but not bound in [s1] was generated
      in the frame between [s1] and [s2]. *)
  Lemma state_bound_bound_between :
    forall (s1 s2 : IRState) (bid : block_id),
      state_bound s2 bid ->
      ~(state_bound s1 bid) ->
      state_bound_between s1 s2 bid.
  Proof.
    intros s1 s2 bid BOUND NOTBOUND.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    exists pre, st, st'.
    repeat (split; auto).
    (* Only the lower bound [count st ≥ count s1] is left at this point. *)
    destruct (NatUtil.lt_ge_dec (count st) (count s1)) as [LT | GE].
    - (* If [st] were strictly below [s1], the same witnesses would show
         [state_bound s1 bid], contradicting the hypothesis. *)
      exfalso.
      apply NOTBOUND.
      exists pre, st, st'.
      auto.
    - auto.
  Qed.

  (** An identifier bound in a state [s] cannot belong to a frame that starts
      at a state [s1] whose count is at least that of [s]. *)
  Lemma state_bound_before_not_bound_between :
    forall (s s1 s2 : IRState) (bid : block_id),
      state_bound s bid ->
      (count s <= count s1)%nat ->
      ~ (state_bound_between s1 s2 bid).
  Proof.
    intros s s1 s2 bid BOUND LE.
    intros BETWEEN.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    destruct BETWEEN as (pre' & sm & sm' & PRE' & UPPER & LOWER & GEN').
    (* Both generations produce [bid]; [INJ] forbids this as soon as the
       generating states have different counts. *)
    eapply (INJ GEN GEN'); eauto.
    (* Remaining goal: count st < count s <= count s1 <= count sm. *)
    lia.
  Qed.

  (** [state_bound] is monotone in the state: an identifier bound in [s1]
      stays bound in any [s2] whose count is at least that of [s1]. *)
  Lemma state_bound_mono :
    forall s1 s2 bid,
      state_bound s1 bid ->
      (count s1 <= count s2)%nat ->
      state_bound s2 bid.
  Proof.
    intros s1 s2 bid BOUND LE.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    exists pre, st, st'.
    (* Same witnesses; only the counter bound needs arithmetic. *)
    split; [assumption | split; [lia | assumption]].
  Qed.

  (** Membership in a frame is preserved when the frame is enlarged: the
      lower state may be replaced by one with a smaller-or-equal count and
      the upper state by one with a greater-or-equal count.
      (Despite the name, the resulting frame [s1';s2'] contains [s1;s2].) *)
  Lemma state_bound_between_shrink :
    forall s1 s2 s1' s2' id,
      state_bound_between s1 s2 id ->
      (count s1' <= count s1)%nat ->
      (count s2' >= count s2)%nat ->
      state_bound_between s1' s2' id.
  Proof.
    intros s1 s2 s1' s2' id BETWEEN LOW HIGH.
    destruct BETWEEN as (pre & st & st' & PRE & UPPER & LOWER & GEN).
    exists pre, st, st'.
    split; [exact PRE |].
    split; [lia |].
    split; [lia | exact GEN].
  Qed.

  (** Pointwise lifting of [state_bound_between_shrink] to a list of
      identifiers. *)
  Lemma all_state_bound_between_shrink :
    forall s1 s2 s1' s2' ids,
      Forall (state_bound_between s1 s2) ids ->
      (count s1' <= count s1)%nat ->
      (count s2' >= count s2)%nat ->
      Forall (state_bound_between s1' s2') ids.
  Proof.
    intros s1 s2 s1' s2' bids ALL LOW HIGH.
    (* Turn both [Forall] statements into their element-wise form. *)
    rewrite Forall_forall in ALL |- *.
    intros x IN.
    eapply state_bound_between_shrink; eauto.
  Qed.
  
  (** Identifiers coming from two frames are distinct when the first frame
      ends no later than the second one starts. *)
  Lemma state_bound_between_separate :
    forall s1 s2 s3 s4 id id',
      state_bound_between s1 s2 id ->
      state_bound_between s3 s4 id' ->
      (count s2 <= count s3)%nat ->
      id ≢ id'.
  Proof.
    intros s1 s2 s3 s4 id id' IN12 IN34 ORDER.
    destruct IN12 as (pre & st & st' & PRE & UPPER & LOWER & GEN).
    destruct IN34 as (pre' & sm & sm' & PRE' & UPPER' & LOWER' & GEN').
    (* count st < count s2 <= count s3 <= count sm, so the counts differ. *)
    eapply (INJ GEN GEN'); [lia | assumption | assumption].
  Qed.

  (** The same identifier cannot belong to two frames when the first ends no
      later than the second starts. *)
  Lemma state_bound_between_id_separate :
    forall s1 s2 s3 s4 id,
      state_bound_between s1 s2 id ->
      state_bound_between s3 s4 id ->
      (count s2 <= count s3)%nat ->
      False.
  Proof.
    intros s1 s2 s3 s4 id IN12 IN34 ORDER.
    (* [state_bound_between_separate] gives [id ≢ id], which is absurd. *)
    apply (state_bound_between_separate IN12 IN34 ORDER).
    reflexivity.
  Qed.

  (** If an identifier belongs neither to the frame [s1;s2] nor to the frame
      [s2;s3], then it does not belong to the frame [s1;s3]. *)
  Lemma not_state_bound_between_split :
    forall (s1 s2 s3 : IRState) id,
      ~ state_bound_between s1 s2 id ->
      ~ state_bound_between s2 s3 id ->
      ~ state_bound_between s1 s3 id.
  Proof.
    intros s1 s2 s3 id NOT12 NOT23.
    intros BOUND.
    destruct BOUND as (pre & st & st' & PRE & UPPER & LOWER & GEN).
    (* Case split on which side of [s2] the generating state falls. *)
    assert (count st < count s2 \/ count st >= count s2)%nat
      as [BELOW | ABOVE] by lia.
    - apply NOT12.
      exists pre, st, st'.
      auto.
    - apply NOT23.
      exists pre, st, st'.
      auto.
  Qed.

  (** An identifier freshly generated from [s1] is not bound in [s1]. *)
  Lemma gen_not_state_bound :
    forall name s1 s2 id,
      is_correct_prefix name ->
      gen name s1 ≡ inr (s2, id) ->
      ~(state_bound s1 id).
  Proof.
    intros name s1 s2 id PRE GEN.
    intros BOUND.
    destruct BOUND as (pre' & st & st' & PRE' & BELOW & GEN').
    (* [id] would be generated both from [s1] and from the strictly lower
       state [st], which [INJ] rules out. *)
    eapply INJ.
    symmetry; apply GEN.
    apply GEN'.
    all: auto.
    lia.
  Qed.

 (** An identifier generated from [s1] is bound in the resulting state [s2]. *)
 Lemma gen_state_bound :
    forall name s1 s2 id,
      is_correct_prefix name ->
      gen name s1 ≡ inr (s2, id) ->
      state_bound s2 id.
  Proof.
    intros name s1 s2 id PRE GEN.
    exists name, s1, s2.
    split; [exact PRE |].
    split; [| symmetry; exact GEN].
    (* [count s1 < count s2] is exactly what [MONO] gives for this call. *)
    eapply MONO; eauto.
  Qed.

  (** An identifier generated from [s1] with resulting state [s2] belongs to
      the frame [s1;s2]. *)
  Lemma gen_state_bound_between :
    forall name s1 s2 id,
      is_correct_prefix name ->
      gen name s1 ≡ inr (s2, id) ->
      state_bound_between s1 s2 id.
  Proof.
    intros name s1 s2 id PRE GEN.
    (* Bound in [s2] but not in [s1]. *)
    apply state_bound_bound_between;
      [eapply gen_state_bound | eapply gen_not_state_bound];
      eauto.
  Qed.

  (** An identifier generated from [s1] is not bound in any state [s'] whose
      count is at most that of [s1]. *)
  Lemma not_id_bound_gen_mono :
    forall name s1 s2 s' id,
      gen name s1 ≡ inr (s2, id) ->
      (count s' <= count s1)%nat ->
      is_correct_prefix name ->
      ~ (state_bound s' id).
  Proof.
    intros name s1 s2 s' id GEN LE PRE.
    intros BOUND.
    destruct BOUND as (pre' & st & st' & PRE' & BELOW & GEN').
    eapply INJ.
    symmetry; apply GEN.
    apply GEN'.
    (* count st < count s' <= count s1, so the two counts differ. *)
    lia.
    all: auto.
  Qed.

  (** Identifiers from two frames are distinct when the first frame ends no
      later than the second starts. Same statement as
      [state_bound_between_separate], with the identifiers quantified
      first. *)
  Lemma state_bound_between_disjoint_neq :
    forall x y s1 s2 s3 s4,
      state_bound_between s1 s2 x ->
      state_bound_between s3 s4 y ->
      (count s2 <= count s3)%nat ->
      x ≢ y.
  Proof.
    intros x y s1 s2 s3 s4 IN12 IN34 ORDER.
    exact (state_bound_between_separate IN12 IN34 ORDER).
  Qed.

  (** Two lists of identifiers drawn from frames that do not overlap (the
      first ends no later than the second starts) are disjoint. *)
  Lemma state_bound_between_list_disjoint :
    forall l1 l2 s1 s2 s3 s4,
      Forall (state_bound_between s1 s2) l1 ->
      Forall (state_bound_between s3 s4) l2 ->
      (count s2 <= count s3)%nat ->
      Coqlib.list_disjoint l1 l2.
  Proof.
    intros l1 l2 s1 s2 s3 s4 ALL1 ALL2 ORDER.
    unfold Coqlib.list_disjoint; intros x y IN1 IN2.
    (* Element-wise view of both [Forall] hypotheses. *)
    rewrite Forall_forall in ALL1, ALL2.
    eapply state_bound_between_disjoint_neq; eauto.
  Qed.

  (** Concatenating two repetition-free lists of identifiers drawn from
      non-overlapping frames yields a repetition-free list. *)
  Lemma state_bound_between_disjoint_norepet :
    forall l1 l2 s1 s2 s3 s4,
      Coqlib.list_norepet l1 ->
      Coqlib.list_norepet l2 ->
      Forall (state_bound_between s1 s2) l1 ->
      Forall (state_bound_between s3 s4) l2 ->
      (count s2 <= count s3)%nat ->
      Coqlib.list_norepet (l1 ++ l2).
  Proof.
    intros l1 l2 s1 s2 s3 s4 NR1 NR2 ALL1 ALL2 ORDER.
    apply Coqlib.list_norepet_append; try assumption.
    (* What remains is disjointness of the two lists. *)
    eapply state_bound_between_list_disjoint; eassumption.
  Qed.

  (** An identifier bound in [s1] belongs to some frame ending at any [s2]
      whose count is at least that of [s1]; the frame can be started at the
      very state the identifier was generated from. *)
  Lemma state_bound_before_bound_between :
    forall s1 s2 id,
      state_bound s1 id ->
      count s1 <= count s2 ->
      exists s0,
        state_bound_between s0 s2 id.
  Proof.
    intros s1 s2 id BOUND LE.
    destruct BOUND as (pre & st & st' & PRE & BELOW & GEN).
    (* Take [s0 := st] and reuse the witnesses of [state_bound]. *)
    exists st, pre, st, st'.
    split; [exact PRE |].
    split; [lia |].
    split; [lia | exact GEN].
  Qed.

End StateBound.
